#!/usr/bin/env python3
#
# Copyright 2018 Google LLC
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import importlib
import sys

import cipherlist
import hexjson
import paths
import tvgen

# Verbosity flag forwarded to tvgen.check_testvector / tvgen.check_tests;
# nothing in this file sets it to anything other than False.
verbose = False
# Base directory (<top>/test_vectors/ours) handed to tvgen and used as the
# root of the per-cipher "other.json" output paths.
path = paths.top / "test_vectors" / "ours"

def write_other(ciphers):
    """Write each cipher's ``other.json`` file of test vectors.

    For every cipher, the module ``test_<lowercased cipher name>`` is imported
    and the result of its ``test_vectors(cipher)`` is written with
    ``hexjson.write_using_hex`` to ``path / <cipher name> / "other.json"``.

    Args:
        ciphers: Iterable of cipher objects, or None to use
            ``cipherlist.common_ciphers``.
    """
    selected = cipherlist.common_ciphers if ciphers is None else ciphers
    for cipher in selected:
        # The test module is named after the lowercased cipher name, while the
        # output directory uses the name exactly as the cipher reports it.
        module_name = f"test_{cipher.name().lower()}"
        out_file = path / cipher.name() / "other.json"
        print(f"Writing: {out_file}")
        vectors_module = importlib.import_module(module_name)
        hexjson.write_using_hex(out_file, vectors_module.test_vectors(cipher))

def write_ours(ciphers):
    """Call ``tvgen.write_tests(cipher, path)`` for every selected cipher.

    Args:
        ciphers: Iterable of cipher objects, or None to use
            ``cipherlist.our_test_ciphers``.
    """
    selected = cipherlist.our_test_ciphers if ciphers is None else ciphers
    for cipher in selected:
        tvgen.write_tests(cipher, path)

def write(ciphers):
    """Run ``write_other`` and then ``write_ours`` with the same argument.

    Args:
        ciphers: Iterable of cipher objects, or None; passed through unchanged,
            so each writer applies its own default when it is None.
    """
    for writer in (write_other, write_ours):
        writer(ciphers)

def check_other(ciphers):
    """Check the vectors supplied by each cipher's ``test_<name>`` module.

    For every cipher a banner with the lowercased cipher name is printed, the
    module ``test_<lowercased cipher name>`` is imported, and each item yielded
    by its ``test_vectors(cipher)`` is passed to ``tvgen.check_testvector``
    together with the module-level ``verbose`` flag.

    Args:
        ciphers: Iterable of cipher objects, or None to use
            ``cipherlist.common_ciphers``.
    """
    selected = cipherlist.common_ciphers if ciphers is None else ciphers
    for cipher in selected:
        lowered = cipher.name().lower()
        print(f"======== {lowered} ========")
        vectors_module = importlib.import_module(f"test_{lowered}")
        vectors = vectors_module.test_vectors(cipher)
        for vector in vectors:
            tvgen.check_testvector(cipher, vector, verbose)

def check(ciphers):
    """Call ``tvgen.check_tests(cipher, path, verbose)`` for every cipher.

    Args:
        ciphers: Iterable of cipher objects, or None to use
            ``cipherlist.our_test_ciphers``.
    """
    selected = cipherlist.our_test_ciphers if ciphers is None else ciphers
    for cipher in selected:
        tvgen.check_tests(cipher, path, verbose)

def main():
    """Dispatch the command named on the command line.

    Usage: ``<script> COMMAND [CIPHER_NAME]`` where COMMAND is one of
    ``write_other``, ``write_ours``, ``write``, ``check_other`` or ``check``.
    When CIPHER_NAME is given, only the entries of ``cipherlist.all_ciphers``
    whose ``name()`` equals it are processed; otherwise the handler receives
    None and applies its own default cipher list.

    Raises:
        Exception: If the argument count is wrong, the cipher name matches no
            known cipher, or the command is not recognised.
    """
    argv = sys.argv
    usage = "Bad args: {} use write or check"
    if not 2 <= len(argv) <= 3:
        raise Exception(usage.format(argv[1:]))

    command = argv[1]
    cipher_name = argv[2] if len(argv) == 3 else None

    # The cipher name is resolved before the command is validated, so an
    # unknown cipher is reported ahead of an unknown command.
    ciphers = None
    if cipher_name is not None:
        ciphers = [c for c in cipherlist.all_ciphers if c.name() == cipher_name]
        if not ciphers:
            raise Exception(f"No cipher known called {cipher_name}")

    handlers = {
        "write_other": write_other,
        "write_ours": write_ours,
        "write": write,
        "check_other": check_other,
        "check": check,
    }
    handler = handlers.get(command)
    if handler is None:
        raise Exception(usage.format(argv[1:]))
    handler(ciphers)

# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
